{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# An empty CUDA_VISIBLE_DEVICES hides every GPU, so TensorFlow runs on the CPU.\n",
    "# It has to be set before TensorFlow is imported in the next cell.\n",
    "os.environ.update(CUDA_VISIBLE_DEVICES='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/xlnet/model_utils.py:295: The name tf.train.Optimizer is deprecated. Please use tf.compat.v1.train.Optimizer instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import xlnet\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tqdm import tqdm\n",
    "import model_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sentencepiece as spm\n",
    "from prepro_utils import preprocess_text, encode_ids\n",
    "\n",
    "SP_MODEL_PATH = 'sp10m.cased.v9.model'\n",
    "\n",
    "# SentencePiece tokenizer model; Load() reports success with True.\n",
    "sp_model = spm.SentencePieceProcessor()\n",
    "sp_model.Load(SP_MODEL_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/xlnet/xlnet.py:63: The name tf.gfile.Open is deprecated. Please use tf.io.gfile.GFile instead.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Run-time options for building the XLNet graph; both dropout rates are 0.0.\n",
    "kwargs = {\n",
    "    'is_training': True,\n",
    "    'use_tpu': False,\n",
    "    'use_bfloat16': False,\n",
    "    'dropout': 0.0,\n",
    "    'dropatt': 0.0,\n",
    "    'init': 'normal',\n",
    "    'init_range': 0.1,\n",
    "    'init_std': 0.05,\n",
    "    'clamp_len': -1,\n",
    "}\n",
    "\n",
    "xlnet_parameters = xlnet.RunConfig(**kwargs)\n",
    "\n",
    "# Architecture hyper-parameters (n_layer, d_model, n_token, ...) are read from disk.\n",
    "xlnet_config = xlnet.XLNetConfig(json_path='output-model/config.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'keys': ['n_layer',\n",
       "  'd_model',\n",
       "  'n_head',\n",
       "  'd_head',\n",
       "  'd_inner',\n",
       "  'ff_activation',\n",
       "  'untie_r',\n",
       "  'n_token'],\n",
       " 'n_layer': 12,\n",
       " 'd_model': 768,\n",
       " 'n_head': 12,\n",
       " 'd_head': 64,\n",
       " 'd_inner': 3072,\n",
       " 'ff_activation': 'gelu',\n",
       " 'untie_r': True,\n",
       " 'n_token': 32000}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display every hyper-parameter held by the loaded XLNet config.\n",
    "vars(xlnet_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32000"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Vocabulary size: rows of the word-embedding table and length of the LM-head bias built in Model.\n",
    "xlnet_config.n_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model:\n",
    "    \"\"\"XLNet encoder with a language-model head tied to the word embeddings.\n",
    "\n",
    "    The graph is built from the notebook-level `xlnet_config` and `xlnet_parameters`.\n",
    "\n",
    "    Attributes:\n",
    "        X: int32 token-id placeholder, fed batch-major (presumably [batch, seq_len]).\n",
    "        segment_ids: int32 segment-id placeholder, same layout as `X`.\n",
    "        input_masks: float32 mask placeholder, same layout as `X`. NOTE(review):\n",
    "            XLNet's convention is presumably 1.0 = padding, 0.0 = real token;\n",
    "            confirm against xlnet.py.\n",
    "        output: rank-3 sequence output of `XLNetModel`, presumably\n",
    "            (seq_len, batch, d_model).\n",
    "        logits: LM logits of shape (i, b, n_token), where (i, b) are the two\n",
    "            leading axes of `output`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "    ):\n",
    "        self.X = tf.placeholder(tf.int32, [None, None])\n",
    "        self.segment_ids = tf.placeholder(tf.int32, [None, None])\n",
    "        self.input_masks = tf.placeholder(tf.float32, [None, None])\n",
    "\n",
    "        # XLNetModel takes time-major [seq_len, batch] inputs, so the two axes of\n",
    "        # the batch-major placeholders are swapped before being passed in.\n",
    "        xlnet_model = xlnet.XLNetModel(\n",
    "            xlnet_config=xlnet_config,\n",
    "            run_config=xlnet_parameters,\n",
    "            input_ids=tf.transpose(self.X, [1, 0]),\n",
    "            seg_ids=tf.transpose(self.segment_ids, [1, 0]),\n",
    "            input_mask=tf.transpose(self.input_masks, [1, 0]))\n",
    "\n",
    "        self.output = xlnet_model.get_sequence_output()\n",
    "        lookup_table = xlnet_model.get_embedding_table()\n",
    "\n",
    "        # The bias lives at 'model/lm_loss/bias', presumably matching the pre-training\n",
    "        # LM head so it can be restored from a checkpoint.\n",
    "        with tf.variable_scope('model', reuse=tf.AUTO_REUSE):\n",
    "            with tf.variable_scope('lm_loss'):\n",
    "                softmax_b = tf.get_variable(\n",
    "                    'bias', [xlnet_config.n_token], dtype=self.output.dtype,\n",
    "                    initializer=tf.zeros_initializer())\n",
    "                # Output projection reuses the input embedding table (weight tying):\n",
    "                # (i, b, d) x (n_token, d) -> (i, b, n_token). Kept on self so\n",
    "                # callers can evaluate it with sess.run(model.logits, ...).\n",
    "                self.logits = tf.einsum('ibd,nd->ibn', self.output, lookup_table) + softmax_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /home/husein/xlnet/xlnet.py:220: The name tf.variable_scope is deprecated. Please use tf.compat.v1.variable_scope instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/xlnet/xlnet.py:220: The name tf.AUTO_REUSE is deprecated. Please use tf.compat.v1.AUTO_REUSE instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/xlnet/modeling.py:453: The name tf.logging.info is deprecated. Please use tf.compat.v1.logging.info instead.\n",
      "\n",
      "INFO:tensorflow:memory input None\n",
      "INFO:tensorflow:Use float type <dtype: 'float32'>\n",
      "WARNING:tensorflow:From /home/husein/xlnet/modeling.py:460: The name tf.get_variable is deprecated. Please use tf.compat.v1.get_variable instead.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/xlnet/modeling.py:535: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.dropout instead.\n",
      "WARNING:tensorflow:From /home/husein/.local/lib/python3.6/site-packages/tensorflow_core/python/layers/core.py:271: Layer.apply (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `layer.__call__` method instead.\n",
      "WARNING:tensorflow:\n",
      "The TensorFlow contrib module will not be included in TensorFlow 2.0.\n",
      "For more information, please see:\n",
      "  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md\n",
      "  * https://github.com/tensorflow/addons\n",
      "  * https://github.com/tensorflow/io (for I/O related ops)\n",
      "If you depend on functionality not listed there, please file an issue.\n",
      "\n",
      "WARNING:tensorflow:From /home/husein/xlnet/modeling.py:67: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use keras.layers.Dense instead.\n"
     ]
    }
   ],
   "source": [
    "# Build the model in a fresh default graph, then run every variable initializer.\n",
    "# No weights are restored from a checkpoint in this cell.\n",
    "tf.reset_default_graph()\n",
    "sess = tf.InteractiveSession()\n",
    "model = Model()\n",
    "init_op = tf.global_variables_initializer()\n",
    "sess.run(init_op)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "import re\n",
    "\n",
    "def get_assignment_map_from_checkpoint(tvars, init_checkpoint):\n",
    "    \"\"\"Match graph variables to checkpoint variables by name.\n",
    "\n",
    "    Args:\n",
    "        tvars: iterable of graph variables; each `.name` may carry a\n",
    "            ':<index>' suffix (e.g. 'scope/w:0'), which is stripped before matching.\n",
    "        init_checkpoint: checkpoint path passed to `tf.train.list_variables`.\n",
    "\n",
    "    Returns:\n",
    "        (assignment_map, initialized_variable_names) where assignment_map is an\n",
    "        OrderedDict, in checkpoint listing order, mapping each checkpoint variable\n",
    "        name that also exists in the graph to that graph variable, and\n",
    "        initialized_variable_names is a dict with every matched name, both bare\n",
    "        and with a ':0' suffix, mapped to 1.\n",
    "    \"\"\"\n",
    "    name_to_variable = collections.OrderedDict()\n",
    "    for var in tvars:\n",
    "        # Drop the ':<output index>' suffix so names line up with checkpoint keys.\n",
    "        m = re.match(r'^(.*):\\d+$', var.name)\n",
    "        name = m.group(1) if m is not None else var.name\n",
    "        name_to_variable[name] = var\n",
    "\n",
    "    assignment_map = collections.OrderedDict()\n",
    "    initialized_variable_names = {}\n",
    "    # list_variables yields (name, shape) pairs; only the name is needed here.\n",
    "    for name, _shape in tf.train.list_variables(init_checkpoint):\n",
    "        if name not in name_to_variable:\n",
    "            continue\n",
    "        assignment_map[name] = name_to_variable[name]\n",
    "        initialized_variable_names[name] = 1\n",
    "        initialized_variable_names[name + ':0'] = 1\n",
    "\n",
    "    return assignment_map, initialized_variable_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "OrderedDict([('model/transformer/r_w_bias',\n",
       "              <tf.Variable 'model/transformer/r_w_bias:0' shape=(12, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/r_r_bias',\n",
       "              <tf.Variable 'model/transformer/r_r_bias:0' shape=(12, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/word_embedding/lookup_table',\n",
       "              <tf.Variable 'model/transformer/word_embedding/lookup_table:0' shape=(32000, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/r_s_bias',\n",
       "              <tf.Variable 'model/transformer/r_s_bias:0' shape=(12, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/seg_embed',\n",
       "              <tf.Variable 'model/transformer/seg_embed:0' shape=(12, 2, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_0/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_0/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_0/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_0/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_0/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_0/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_0/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_0/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_0/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_0/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_0/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_0/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_0/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_0/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_1/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_1/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_1/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_1/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_1/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_1/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_1/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_1/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_1/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_1/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_1/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_1/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_1/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_1/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_2/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_2/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_2/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_2/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_2/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_2/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_2/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_2/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_2/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_2/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_2/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_2/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_2/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_2/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_3/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_3/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_3/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_3/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_3/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_3/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_3/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_3/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_3/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_3/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_3/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_3/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_3/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_3/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_4/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_4/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_4/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_4/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_4/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_4/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_4/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_4/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_4/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_4/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_4/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_4/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_4/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_4/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_5/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_5/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_5/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_5/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_5/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_5/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_5/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_5/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_5/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_5/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_5/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_5/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_5/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_5/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_6/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_6/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_6/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_6/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_6/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_6/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_6/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_6/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_6/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_6/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_6/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_6/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_6/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_6/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_7/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_7/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_7/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_7/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_7/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_7/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_7/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_7/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_7/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_7/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_7/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_7/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_7/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_7/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_8/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_8/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_8/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_8/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_8/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_8/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_8/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_8/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_8/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_8/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_8/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_8/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_8/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_8/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_9/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_9/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_9/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_9/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_9/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_9/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_9/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_9/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_9/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_9/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_9/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_9/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_9/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_9/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_10/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_10/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_10/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_10/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_10/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_10/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_10/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_10/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_10/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_10/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_10/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_10/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_10/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_10/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/rel_attn/q/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_11/rel_attn/q/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/rel_attn/k/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_11/rel_attn/k/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/rel_attn/v/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_11/rel_attn/v/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/rel_attn/r/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_11/rel_attn/r/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/rel_attn/o/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_11/rel_attn/o/kernel:0' shape=(768, 12, 64) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/rel_attn/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_11/rel_attn/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/rel_attn/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_11/rel_attn/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/ff/layer_1/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_11/ff/layer_1/kernel:0' shape=(768, 3072) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/ff/layer_1/bias',\n",
       "              <tf.Variable 'model/transformer/layer_11/ff/layer_1/bias:0' shape=(3072,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/ff/layer_2/kernel',\n",
       "              <tf.Variable 'model/transformer/layer_11/ff/layer_2/kernel:0' shape=(3072, 768) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/ff/layer_2/bias',\n",
       "              <tf.Variable 'model/transformer/layer_11/ff/layer_2/bias:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/ff/LayerNorm/beta',\n",
       "              <tf.Variable 'model/transformer/layer_11/ff/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/transformer/layer_11/ff/LayerNorm/gamma',\n",
       "              <tf.Variable 'model/transformer/layer_11/ff/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>),\n",
       "             ('model/lm_loss/bias',\n",
       "              <tf.Variable 'model/lm_loss/bias:0' shape=(32000,) dtype=float32_ref>)])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import collections\n",
    "import re\n",
    "\n",
    "# Map each trainable variable's graph name, with the ':<index>' suffix removed,\n",
    "# to the tf.Variable itself (presumably for matching against checkpoint variable names).\n",
    "tvars = tf.trainable_variables()\n",
    "\n",
    "# NOTE(review): not populated in this cell; presumably filled by a later cell.\n",
    "assignment_map = {}\n",
    "initialized_variable_names = {}\n",
    "\n",
    "name_to_variable = collections.OrderedDict()\n",
    "for var in tvars:\n",
    "    name = var.name\n",
    "    # strip the ':<digits>' suffix, e.g. 'model/lm_loss/bias:0' -> 'model/lm_loss/bias'\n",
    "    m = re.match(r'^(.*):\\d+$', name)\n",
    "    if m is not None:\n",
    "        name = m.group(1)\n",
    "    name_to_variable[name] = var\n",
    "\n",
    "name_to_variable"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('beta1_power', []),\n",
       " ('beta2_power', []),\n",
       " ('global_step', []),\n",
       " ('model/lm_loss/bias', [32000]),\n",
       " ('model/lm_loss/bias/Adam', [32000]),\n",
       " ('model/lm_loss/bias/Adam_1', [32000]),\n",
       " ('model/transformer/layer_0/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_0/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_0/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_0/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_0/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_0/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_0/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_0/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_0/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_0/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_0/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_0/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_0/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_0/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_0/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_0/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_0/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_0/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_0/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_0/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_0/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_0/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_0/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_0/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_0/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_0/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_1/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_1/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_1/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_1/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_1/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_1/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_1/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_1/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_1/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_1/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_1/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_1/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_1/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_1/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_1/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_1/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_1/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_1/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_1/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_1/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_1/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_1/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_1/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_1/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_1/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_10/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_10/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_10/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_10/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_10/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_10/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_10/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_10/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_10/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_10/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_10/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_10/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_10/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_10/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_10/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_10/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_10/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_10/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_10/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_10/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_10/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_10/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_10/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_10/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_10/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_11/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_11/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_11/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_11/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_11/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_11/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_11/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_11/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_11/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_11/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_11/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_11/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_11/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_11/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_11/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_11/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_11/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_11/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_11/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_11/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_11/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_11/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_11/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_11/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_11/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_2/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_2/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_2/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_2/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_2/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_2/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_2/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_2/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_2/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_2/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_2/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_2/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_2/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_2/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_2/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_2/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_2/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_2/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_2/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_2/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_2/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_2/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_2/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_2/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_2/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_3/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_3/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_3/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_3/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_3/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_3/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_3/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_3/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_3/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_3/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_3/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_3/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_3/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_3/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_3/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_3/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_3/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_3/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_3/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_3/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_3/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_3/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_3/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_3/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_3/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_4/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_4/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_4/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_4/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_4/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_4/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_4/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_4/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_4/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_4/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_4/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_4/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_4/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_4/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_4/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_4/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_4/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_4/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_4/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_4/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_4/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_4/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_4/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_4/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_4/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_5/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_5/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_5/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_5/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_5/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_5/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_5/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_5/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_5/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_5/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_5/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_5/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_5/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_5/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_5/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_5/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_5/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_5/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_5/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_5/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_5/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_5/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_5/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_5/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_5/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_6/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_6/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_6/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_6/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_6/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_6/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_6/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_6/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_6/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_6/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_6/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_6/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_6/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_6/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_6/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_6/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_6/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_6/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_6/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_6/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_6/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_6/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_6/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_6/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_6/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_7/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_7/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_7/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_7/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_7/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_7/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_7/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_7/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_7/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_7/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_7/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_7/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_7/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_7/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_7/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_7/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_7/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_7/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_7/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_7/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_7/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_7/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_7/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_7/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_7/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_8/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_8/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_8/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_8/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_8/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_8/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_8/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_8/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_8/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_8/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_8/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_8/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_8/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_8/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_8/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_8/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_8/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_8/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_8/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_8/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_8/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_8/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_8/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_8/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_8/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/ff/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_9/ff/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_9/ff/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_9/ff/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_9/ff/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_9/ff/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_9/ff/layer_1/bias', [3072]),\n",
       " ('model/transformer/layer_9/ff/layer_1/bias/Adam', [3072]),\n",
       " ('model/transformer/layer_9/ff/layer_1/bias/Adam_1', [3072]),\n",
       " ('model/transformer/layer_9/ff/layer_1/kernel', [768, 3072]),\n",
       " ('model/transformer/layer_9/ff/layer_1/kernel/Adam', [768, 3072]),\n",
       " ('model/transformer/layer_9/ff/layer_1/kernel/Adam_1', [768, 3072]),\n",
       " ('model/transformer/layer_9/ff/layer_2/bias', [768]),\n",
       " ('model/transformer/layer_9/ff/layer_2/bias/Adam', [768]),\n",
       " ('model/transformer/layer_9/ff/layer_2/bias/Adam_1', [768]),\n",
       " ('model/transformer/layer_9/ff/layer_2/kernel', [3072, 768]),\n",
       " ('model/transformer/layer_9/ff/layer_2/kernel/Adam', [3072, 768]),\n",
       " ('model/transformer/layer_9/ff/layer_2/kernel/Adam_1', [3072, 768]),\n",
       " ('model/transformer/layer_9/rel_attn/LayerNorm/beta', [768]),\n",
       " ('model/transformer/layer_9/rel_attn/LayerNorm/beta/Adam', [768]),\n",
       " ('model/transformer/layer_9/rel_attn/LayerNorm/beta/Adam_1', [768]),\n",
       " ('model/transformer/layer_9/rel_attn/LayerNorm/gamma', [768]),\n",
       " ('model/transformer/layer_9/rel_attn/LayerNorm/gamma/Adam', [768]),\n",
       " ('model/transformer/layer_9/rel_attn/LayerNorm/gamma/Adam_1', [768]),\n",
       " ('model/transformer/layer_9/rel_attn/k/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/k/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/k/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/o/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/o/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/o/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/q/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/q/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/q/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/r/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/r/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/r/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/v/kernel', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/v/kernel/Adam', [768, 12, 64]),\n",
       " ('model/transformer/layer_9/rel_attn/v/kernel/Adam_1', [768, 12, 64]),\n",
       " ('model/transformer/mask_emb/mask_emb', [1, 1, 768]),\n",
       " ('model/transformer/mask_emb/mask_emb/Adam', [1, 1, 768]),\n",
       " ('model/transformer/mask_emb/mask_emb/Adam_1', [1, 1, 768]),\n",
       " ('model/transformer/r_r_bias', [12, 12, 64]),\n",
       " ('model/transformer/r_r_bias/Adam', [12, 12, 64]),\n",
       " ('model/transformer/r_r_bias/Adam_1', [12, 12, 64]),\n",
       " ('model/transformer/r_s_bias', [12, 12, 64]),\n",
       " ('model/transformer/r_s_bias/Adam', [12, 12, 64]),\n",
       " ('model/transformer/r_s_bias/Adam_1', [12, 12, 64]),\n",
       " ('model/transformer/r_w_bias', [12, 12, 64]),\n",
       " ('model/transformer/r_w_bias/Adam', [12, 12, 64]),\n",
       " ('model/transformer/r_w_bias/Adam_1', [12, 12, 64]),\n",
       " ('model/transformer/seg_embed', [12, 2, 12, 64]),\n",
       " ('model/transformer/seg_embed/Adam', [12, 2, 12, 64]),\n",
       " ('model/transformer/seg_embed/Adam_1', [12, 2, 12, 64]),\n",
       " ('model/transformer/word_embedding/lookup_table', [32000, 768]),\n",
       " ('model/transformer/word_embedding/lookup_table/Adam', [32000, 768]),\n",
       " ('model/transformer/word_embedding/lookup_table/Adam_1', [32000, 768])]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Source checkpoint: inspected here, restored below, then re-exported.\n",
    "checkpoint = 'output-model/model.ckpt-300000'\n",
    "tf.train.list_variables(checkpoint)  # -> list of (variable name, shape) pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# `assignment_map` is handed to the restoring Saver in the next cell.\n",
    "assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(\n",
    "    tvars, checkpoint\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from output-model/model.ckpt-300000\n"
     ]
    }
   ],
   "source": [
    "# Restore only the variables listed in `assignment_map` into the live session.\n",
    "saver = tf.train.Saver(var_list=assignment_map)\n",
    "saver.restore(sess, checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'xlnet-base/model.ckpt'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Re-export only the trainable variables; non-trainable ones in the graph\n",
    "# (e.g. optimizer slots or the global step, if present) are not written.\n",
    "export_path = 'xlnet-base/model.ckpt'\n",
    "# tf.train.Saver.save raises if the parent directory does not exist,\n",
    "# so create it up front to keep Restart & Run All working on a fresh checkout.\n",
    "os.makedirs(os.path.dirname(export_path), exist_ok=True)\n",
    "saver = tf.train.Saver(tf.trainable_variables())\n",
    "saver.save(sess, export_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
